{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notebook Setup\n",
    "The following cell will install Drake, check out the underactuated repository, and set up the path (only if necessary).\n",
    "- On Google's Colaboratory, this **will take approximately two minutes** the first time it runs (to provision the machine), but should only need to reinstall once every 12 hours.  Colab will ask you to \"Reset all runtimes\"; say no to save yourself the reinstall.\n",
    "- On Binder, the machines should already be provisioned by the time you can run this; it should return (almost) instantly.\n",
    "\n",
    "More details are available [here](http://underactuated.mit.edu/drake.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "  import pydrake\n",
    "  import underactuated\n",
    "except ImportError:\n",
    "  !curl -s https://raw.githubusercontent.com/RussTedrake/underactuated/master/scripts/setup/jupyter_setup.py > jupyter_setup.py\n",
    "  from jupyter_setup import setup_underactuated\n",
    "  setup_underactuated()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Value Iteration using Neural Networks as a Function Approximator\n",
    "\n",
    "In this notebook, we'll use [PyTorch](https://pytorch.org/tutorials/) to implement a basic fitted value iteration algorithm using neural networks.\n",
    "\n",
    "Let's start by setting up the double integrator plant, the neural network architecture, and two potential cost functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython import get_ipython\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib import cm\n",
    "from pydrake.all import DiscreteAlgebraicRiccatiEquation\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from underactuated.jupyter import SetupMatplotlibBackend\n",
    "plt_is_interactive = SetupMatplotlibBackend()\n",
    "\n",
    "# Define the double integrator\n",
    "A = torch.tensor([[0., 1.], [0., 0.]])\n",
    "B = torch.tensor([[0.], [1.]])\n",
    "At = A.transpose(0, 1)\n",
    "Bt = B.transpose(0, 1)\n",
    "\n",
    "# Define the function approximator for J\n",
    "class Net(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        # Linear implements y = xA^T + b\n",
    "        self.fc1 = nn.Linear(2, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "    def num_flat_features(self, x):\n",
    "        size = x.size()[1:]  # all dimensions except the batch dimension\n",
    "        num_features = 1\n",
    "        for s in size:\n",
    "            num_features *= s\n",
    "        return num_features\n",
    "\n",
    "# Define the cost function\n",
    "def min_time_cost(xt, ut):\n",
    "    at_goal = torch.isclose(xt, torch.zeros(1,2))\n",
    "    # cost = 1 if ~at_goal * [1;1] >= 1, 0 otherwise.\n",
    "    return torch.min((~at_goal).float().matmul(torch.ones(2,1)), torch.ones(1))\n",
    "\n",
    "Q = torch.eye(2)\n",
    "R = torch.eye(1)\n",
    "def quadratic_regulator_cost(xt, ut):\n",
    "    return xt.matmul(Q.matmul(xt.transpose(-2,-1))) + ut.matmul(R.matmul(ut.transpose(-2,-1)))\n",
    "\n",
    "BRinv = B*R.inverse()\n",
    "\n",
    "def min_time_solution(xt):\n",
    "  # Caveat: this does not take the time discretization (zero-order hold on u) into account.\n",
    "  q = xt[:,:,0]\n",
    "  qdot = xt[:,:,1]\n",
    "  # mask indicates that we are in the regime where u = +1.\n",
    "  mask = ((qdot < 0) & (2*q <= qdot.pow(2))) | ((qdot >= 0) & (2*q < -qdot.pow(2)))\n",
    "  T = torch.empty(q.size())\n",
    "  T[mask] = 2*(.5*(qdot[mask].pow(2)) - q[mask]).sqrt() - qdot[mask]\n",
    "  T[~mask] = qdot[~mask] + 2*(.5*(qdot[~mask].pow(2)) + q[~mask]).sqrt()\n",
    "  return T.unsqueeze(-1)\n",
    "  \n",
    "def quadratic_regulator_solution(xt, timestep):\n",
    "  S = DiscreteAlgebraicRiccatiEquation(A=(np.eye(2)+timestep*A.numpy()),\n",
    "                                       B=timestep*B.numpy(),\n",
    "                                       Q=Q, R=R)\n",
    "  return xt.matmul(torch.from_numpy(S).float().matmul(xt.transpose(-2,-1)))\n",
    "\n",
    "def plot_and_compare(net, running_cost, timestep):\n",
    "  x1s = torch.linspace(-3,3,31)\n",
    "  x2s = torch.linspace(-3,3,51)\n",
    "  X1s, X2s = torch.meshgrid(x1s, x2s)\n",
    "  X = torch.stack((X1s.flatten(), X2s.flatten()), 1).unsqueeze(1)\n",
    "  \n",
    "  with torch.no_grad():\n",
    "    J = net.forward(X)\n",
    "\n",
    "  fig = plt.figure(figsize=(9, 4))\n",
    "  ax1, ax2 = fig.subplots(1, 2, subplot_kw=dict(projection='3d'))\n",
    "  ax1.set_xlabel(\"q\")\n",
    "  ax1.set_ylabel(\"qdot\")\n",
    "  ax1.set_title(\"Estimated Cost-to-Go\")\n",
    "  ax1.plot_surface(X1s, X2s, J.view(X1s.size()).detach().numpy(), rstride=1, cstride=1, cmap=cm.jet)\n",
    "  \n",
    "  if running_cost == min_time_cost:\n",
    "    Jd = min_time_solution(X)\n",
    "  elif running_cost == quadratic_regulator_cost:\n",
    "    Jd = quadratic_regulator_solution(X, timestep)\n",
    "\n",
    "  ax2.set_xlabel(\"q\")\n",
    "  ax2.set_ylabel(\"qdot\")\n",
    "  ax2.set_title(\"Analytical Cost-to-Go\")\n",
    "  ax2.plot_surface(X1s, X2s, Jd.view(X1s.size()).detach().numpy(), rstride=1, cstride=1, cmap=cm.jet)\n",
    "    \n",
    "  # Score is worst absolute different (e.g. infinity-norm) of the samples\n",
    "  criterion = nn.MSELoss()\n",
    "  score = criterion(J, Jd).item()\n",
    "  print(\"MSE(Ĵᵢ,Jᵢ) = %.2f\" % score)   \n"
   ]
  },
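  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid, here is a sketch of the two closed-form cost-to-go functions that `min_time_solution` and `quadratic_regulator_solution` above evaluate. The symbols $h$ (the timestep), $A_d$, $B_d$ and $S$ are notation assumed for this note, not names from the code.\n",
    "\n",
    "For the minimum-time problem (continuous time, ignoring the zero-order hold on $u$):\n",
    "$$J(q, \\dot{q}) = \\begin{cases} 2\\sqrt{\\tfrac{1}{2}\\dot{q}^2 - q} - \\dot{q} & \\text{where } u = +1 \\text{ is optimal,} \\\\ \\dot{q} + 2\\sqrt{\\tfrac{1}{2}\\dot{q}^2 + q} & \\text{otherwise,} \\end{cases}$$\n",
    "where the $u = +1$ regime is $(\\dot{q} < 0 \\text{ and } 2q \\le \\dot{q}^2)$ or $(\\dot{q} \\ge 0 \\text{ and } 2q < -\\dot{q}^2)$.\n",
    "\n",
    "For the quadratic regulator, $J(x) = x^T S x$, where $S$ solves the discrete-time algebraic Riccati equation for the Euler-discretized system $A_d = I + hA$, $B_d = hB$:\n",
    "$$S = A_d^T S A_d - A_d^T S B_d (R + B_d^T S B_d)^{-1} B_d^T S A_d + Q.$$"
   ]
  },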
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Discrete time, continuous state, discrete action\n",
    "\n",
    "This is the standard \"fitted value iteration\" algorithm with a torch network as the function approximator, and a single step of gradient descent performed on each iteration."
   ]
  },
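  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In equations, one pass through the training loop in the next cell does roughly the following. The symbols are notation assumed for this note, not names from the code: $h$ is the timestep, $\\ell$ the running cost, $\\hat{J}_\\theta$ the network, $\\theta^-$ the parameters copied into the target network, $\\eta$ the learning rate, and $\\mathcal{U}$ the grid of sampled inputs.\n",
    "\n",
    "$$J^d_i = \\min_{u \\in \\mathcal{U}} \\left[ h\\,\\ell(x_i, u) + \\hat{J}_{\\theta^-}\\big(x_i + h(Ax_i + Bu)\\big) \\right],$$\n",
    "$$\\theta \\leftarrow \\theta - \\eta\\, \\nabla_\\theta \\frac{1}{N} \\sum_{i=1}^{N} \\left( \\hat{J}_\\theta(x_i) - J^d_i \\right)^2.$$\n",
    "\n",
    "The second line assumes a plain gradient-descent optimizer such as the `optim.SGD` instances created below. Because the target is computed under `torch.no_grad()`, no gradient flows through $J^d_i$."
   ]
  },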
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve(net, optimizer, running_cost):\n",
    "  criterion = nn.MSELoss()\n",
    "  timestep = 0.1\n",
    "\n",
    "  x1s = torch.linspace(-3,3,31)\n",
    "  x2s = torch.linspace(-3,3,51)\n",
    "  us = torch.linspace(-1,1,9)\n",
    "\n",
    "  X1s, X2s = torch.meshgrid(x1s, x2s)\n",
    "  # Want x as batch row vectors... size [num_state_samples, 1, num_states]\n",
    "  # (because the linear units in net expect row vectors)\n",
    "  X = torch.stack((X1s.flatten(), X2s.flatten()), 1).unsqueeze(1)\n",
    "\n",
    "  X1s, X2s, Us = torch.meshgrid(x1s, x2s, us)\n",
    "  # XwithU has size [num_state_samples, num_input_samples, 1, num_states]\n",
    "  # UwithX has size [num_state_samples, num_input_samples, 1, num_inputs]\n",
    "  XwithU = torch.stack((X1s.flatten(0,1), X2s.flatten(0,1)), 2).unsqueeze(2)\n",
    "  UwithX = Us.flatten(0,1).unsqueeze(-1).unsqueeze(-1)\n",
    "\n",
    "  Xnext = XwithU + timestep * (XwithU.matmul(At) + UwithX.matmul(Bt))\n",
    "  G = timestep*running_cost(XwithU, UwithX)\n",
    "\n",
    "  target_net = Net()\n",
    "  for epoch in range(1000 if get_ipython() else 2):\n",
    "    net.zero_grad()\n",
    "    target_net.load_state_dict(net.state_dict())\n",
    "    with torch.no_grad():\n",
    "      Jnext = target_net.forward(Xnext)\n",
    "      Jd, ind = torch.min(G + Jnext, dim=1)\n",
    "    J = net.forward(X)\n",
    "    loss = criterion(J, Jd)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch % 20 == 19:\n",
    "      print('[%d] loss: %.3f' % (epoch + 1, loss.item()))\n",
    "\n",
    "  plot_and_compare(net, running_cost, timestep)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's apply it to the quadratic-regulator cost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.random.manual_seed(12345)  # for scoring\n",
    "net = Net()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.01)\n",
    "solve(net, optimizer, quadratic_regulator_cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's apply it to the minimum-time cost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.random.manual_seed(12345)  # for scoring\n",
    "net = Net()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.0015)\n",
    "solve(net, optimizer, min_time_cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a similar version, but with the states sampled uniformly at random from $[-3, 3]^2$ on each iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve(net, optimizer, running_cost):\n",
    "  criterion = nn.MSELoss()\n",
    "  timestep = 0.1\n",
    "\n",
    "  us = torch.linspace(-1,1,9)\n",
    "  num_state_samples = 1000\n",
    "  XwithU = torch.empty((num_state_samples, us.size()[-1], 1, 2))\n",
    "  UwithX = torch.empty((num_state_samples, us.size()[-1], 1, 1))\n",
    "\n",
    "  target_net = Net()\n",
    "  for epoch in range(1000 if get_ipython() else 2):\n",
    "    X = 6*torch.rand((num_state_samples, 1, 2)) - 3.\n",
    "    X[0, :, :] = torch.zeros(1,2) # make sure zero appears\n",
    "\n",
    "    for i in range(us.size()[-1]):\n",
    "      XwithU[:, i, :, :] = X\n",
    "      UwithX[:, i, :, :] = us[i]\n",
    "\n",
    "    Xnext = XwithU + timestep * (XwithU.matmul(At) + UwithX.matmul(Bt))\n",
    "    G = timestep*running_cost(XwithU, UwithX)\n",
    "\n",
    "    net.zero_grad()\n",
    "    target_net.load_state_dict(net.state_dict())\n",
    "    with torch.no_grad():\n",
    "      Jnext = target_net.forward(Xnext)\n",
    "      Jd, ind = torch.min(G + Jnext, dim=1)\n",
    "    J = net.forward(X)\n",
    "    loss = criterion(J, Jd)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch % 20 == 19:\n",
    "      print('[%d] loss: %.6f' % (epoch + 1, loss.item()))\n",
    "\n",
    "  plot_and_compare(net, running_cost, timestep)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.random.manual_seed(12345)  # for scoring\n",
    "net = Net()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.01)\n",
    "solve(net, optimizer, quadratic_regulator_cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.random.manual_seed(12345)  # for scoring\n",
    "net = Net()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.0015)\n",
    "solve(net, optimizer, min_time_cost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# My challenge to you\n",
    "\n",
    "I am intensely interested in knowing how well these representations can work, compared to the known analytical solutions, but have so far not spent any time tuning the architecture or the learning parameters.  Can you do better? If you can make a substantial improvement to this example, I would love to incorporate your updates here, and record your contributions.\n",
    "\n",
    "While there are no precise rules to this game, ideally I would like to see the same network architecture work for both problems, and a solution that does not bake in any information about the double integrator (e.g. it should work almost immediately on the pendulum, and with only a small number of changes for the acrobot, cart-pole, etc.).\n",
    "\n",
    "<style>\n",
    "th {\n",
    "    font-weight:bold;\n",
    "}\n",
    "</style>    \n",
    "\n",
    "<table style=\"width:90%; .td { text-align:left; };\">\n",
    "    <tr><th style=\"text-align:left\">Contributor</th><th>MSE: Quadratic Regulator</th><th>MSE: Minimum-time problem</th><th>What changed?</th></tr>\n",
    "    <tr><td style=\"text-align:left\">Russ Tedrake</td><td>13928.50</td><td>6.37</td><td>Initial Example</td></tr>\n",
    "</table>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
